import numpy as np
import torch
import lightning
from tqdm import tqdm
import clip
from torch.utils.data import DataLoader, Dataset
from cemcd.models.linear import LinearModel

# Select the "high" float32 matmul precision mode for the whole process.
torch.set_float32_matmul_precision("high")

# Number of consecutive sample indices that getter() maps onto one feature file.
GROUP_SIZE = 50_000
# Tag embedded in the feature file names read by getter().
FOUNDATION_MODEL = "clip_vitl14"
# One-entry cache used by getter(): the chunk most recently read from disk and
# its zero-based chunk index. Both stay None until the first lookup.
currently_loaded_chunk = currently_loaded_chunk_idx = None
def getter(idx):
    """Return the L2-normalised feature entry for global sample index ``idx``.

    Features are stored on disk in files covering ``GROUP_SIZE`` consecutive
    indices each (file numbering starts at 1). Only one file's contents are
    kept in memory at a time, in the module-level ``currently_loaded_chunk``;
    a lookup that falls in a different file replaces it.
    """
    global currently_loaded_chunk, currently_loaded_chunk_idx

    group = idx // GROUP_SIZE
    offset = idx % GROUP_SIZE

    # Only touch the disk when the requested index lives in another file.
    if currently_loaded_chunk_idx != group:
        file_path = f"/datasets/imagenet/train_{FOUNDATION_MODEL}_features_group_{group + 1}.pt"
        currently_loaded_chunk = torch.load(file_path, weights_only=True)
        currently_loaded_chunk_idx = group

    features = currently_loaded_chunk[offset]
    # Scale to unit length along the last dimension.
    return features / features.norm(dim=-1, keepdim=True)

class DiscoveredConceptDataset(Dataset):
    """Pairs the features returned by ``getter`` with per-sample labels.

    Sample ``i`` is ``(getter(i), labels[i])``; the dataset length is the size
    of the first dimension of ``labels``.
    """

    def __init__(self, labels):
        """Store ``labels``, an array-like exposing ``.shape`` and indexing."""
        self.labels = labels

    def __len__(self):
        """Return the number of labelled samples."""
        n_samples = self.labels.shape[0]
        return n_samples

    def __getitem__(self, idx):
        """Return the ``(features, label)`` pair for sample ``idx``."""
        return getter(idx), self.labels[idx]

# Candidate names: one entry per line of 20k.txt, with leading and trailing
# whitespace removed.
with open("20k.txt", "r") as f:
    lines = list(f)
VOCAB = list(map(str.strip, lines))

# Load CLIP ViT-L/14 onto the GPU, using ckpt_dir as the download location,
# and switch it to eval mode. The second value returned by clip.load is not
# used in this script.
ckpt_dir = "/checkpoints/clip"
clip_model, _ = clip.load("ViT-L/14", device="cuda", download_root=ckpt_dir)
clip_model.eval()

def _encode_vocab(names, batch_size=256):
    """Return unit-length CLIP text embeddings for ``names``, one row per name.

    The names are tokenized and encoded ``batch_size`` at a time on the GPU;
    the result is a single float32 tensor on the CPU whose rows are in the
    same order as ``names``. ``names`` must be non-empty.
    """
    encoded = []
    with torch.no_grad():
        for start in tqdm(range(0, len(names), batch_size)):
            tokens = clip.tokenize(names[start:start + batch_size]).cuda()
            features = clip_model.encode_text(tokens).float()
            features /= features.norm(dim=-1, keepdim=True)
            encoded.append(features.cpu())
    return torch.cat(encoded)


# The vocabulary embeddings do not depend on the concept being named, so they
# are computed once here rather than re-encoding every name for every concept.
vocab_features = _encode_vocab(VOCAB) if VOCAB else None

for top_level_concept in [27, 35]: #range(55):
    # One column of labels per discovered concept, one row per training sample.
    discovered_concepts = np.load(f"sampled_discovered_concepts/sampled_discovered_concepts_{top_level_concept+1}.npy")

    for discovered_concept_idx in range(discovered_concepts.shape[1]):
        labels = discovered_concepts[:, discovered_concept_idx]
        # Ratio of negative to positive labels, handed to the model as pos_weight.
        # NOTE(review): a concept with no positive labels divides by zero here.
        pos_weight = (labels == 0).sum() / (labels == 1).sum()
        naming_model = LinearModel(in_dim=768, pos_weight=pos_weight)
        train_dl = DataLoader(DiscoveredConceptDataset(labels), batch_size=256)

        trainer = lightning.Trainer(max_epochs=10)

        trainer.fit(naming_model, train_dl)

        naming_model.freeze()

        best_score = float("-inf")
        best_name = None

        if vocab_features is not None:
            # Score the whole vocabulary in a single forward pass.
            with torch.no_grad():
                scores = naming_model(vocab_features.to(naming_model.device)).reshape(-1).tolist()
            if len(scores) != len(VOCAB):
                raise RuntimeError(
                    f"expected one score per vocabulary entry ({len(VOCAB)}), got {len(scores)}"
                )

            # Strict '>' keeps the earliest entry on ties and never selects a NaN score.
            for name, score in zip(VOCAB, scores):
                if score > best_score:
                    best_score = score
                    best_name = name

        print(f"Top-level concept {top_level_concept+1}, discovered concept {discovered_concept_idx+1}: {best_name}")
        # Append so that results for all concepts of this top-level concept share one file.
        with open(f"sampled_discovered_concepts/naming_results_{top_level_concept+1}_10_epochs.txt", "a") as f:
            f.write(f"Discovered concept {discovered_concept_idx+1}: {best_name}\n")
